import cv2
import torchvision.transforms as transforms
from Models.Yolo import YOLO
import argparse
import torch
from Config.ConfigTrain import *
import numpy as np
from PIL import Image,ImageDraw,ImageFont

def iou(box_one, box_two):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box_one: Indexable whose first four items are ``x1, y1, x2, y2``
            (top-left and bottom-right corners). Extra items are ignored.
        box_two: Second box, same layout as ``box_one``.

    Returns:
        ``intersection / union`` as a float, or ``0.0`` when the boxes do not
        overlap (boxes that merely touch along an edge count as disjoint).
    """
    left = max(box_one[0], box_two[0])
    top = max(box_one[1], box_two[1])
    right = min(box_one[2], box_two[2])
    bottom = min(box_one[3], box_two[3])
    if left >= right or top >= bottom:
        return 0.0

    intersection = (right - left) * (bottom - top)
    area_one = (box_one[2] - box_one[0]) * (box_one[3] - box_one[1])
    area_two = (box_two[2] - box_two[0]) * (box_two[3] - box_two[1])
    # The overlap is contained in both areas, so subtract it once to get the
    # true union; dividing by the plain sum would cap the result at 0.5.
    return intersection / (area_one + area_two - intersection)


def NMS(bounding_boxes,S=7,B=2,img_size=448,confidence_threshold=0.5,iou_threshold=0.0,possible_pred=0.4):
    """Decode grid predictions into pixel boxes and apply greedy NMS.

    Every grid cell is read as ``B`` consecutive ``[x, y, w, h, confidence]``
    groups followed by the per-class scores. ``x``/``y`` are offsets inside
    the cell measured in cell sizes, ``w``/``h`` are fractions of ``img_size``.

    Args:
        bounding_boxes: torch tensor indexed as ``[image][row][col][channel]``
            with ``5 * B + num_classes`` channels per cell.
        S: Number of grid cells along each image side.
        B: Number of box predictions per cell; only the most confident one
            of each cell is considered.
        img_size: Side length in pixels of the (square) network input.
        confidence_threshold: Minimum box confidence for a candidate.
        iou_threshold: A candidate is dropped when its IoU with an already
            kept box is greater than this value.
        possible_pred: Minimum ``confidence * best class score`` for a
            candidate.

    Returns:
        A flat list with the kept boxes of all images, in image order. Each
        box is a list ``[x1, y1, x2, y2, score, class_index, ...]`` where
        ``score`` is ``confidence * class score`` and any items after index 5
        are the remaining raw class scores. An empty input yields ``[]``.
    """
    predictions = bounding_boxes.cpu().detach().numpy().tolist()
    grid_size = img_size / S
    class_offset = 5 * B
    nms_boxes = []

    for image_prediction in predictions:
        candidates = []
        for i in range(S):
            for j in range(S):
                cell = image_prediction[i][j]
                gridX = grid_size * j
                gridY = grid_size * i

                # Keep the most confident of the B boxes; max() returns the
                # first maximum, so ties go to the earlier box.
                best = max(range(B), key=lambda b: cell[5 * b + 4])
                class_possible = cell[class_offset:]
                bounding_box = cell[5 * best:5 * best + 5] + class_possible
                possible = max(class_possible)

                if bounding_box[4] < confidence_threshold:
                    continue
                if bounding_box[4] * possible < possible_pred:
                    continue

                # Convert centre/size form into clamped pixel corners.
                centerX = int(gridX + bounding_box[0] * grid_size)
                centerY = int(gridY + bounding_box[1] * grid_size)
                width = int(bounding_box[2] * img_size)
                height = int(bounding_box[3] * img_size)
                bounding_box[0] = max(0, int(centerX - width / 2))
                bounding_box[1] = max(0, int(centerY - height / 2))
                bounding_box[2] = min(img_size - 1, int(centerX + width / 2))
                bounding_box[3] = min(img_size - 1, int(centerY + height / 2))
                candidates.append(bounding_box)

        # Greedy NMS: visit candidates from most to least confident and keep
        # one only if it does not overlap anything already kept.
        candidates.sort(key=lambda box: box[4], reverse=True)
        kept = []
        for candidate in candidates:
            if all(iou(winner, candidate) <= iou_threshold for winner in kept):
                kept.append(candidate)

        for box in kept:
            class_index = int(np.argmax(box[5:]))
            # Final score = box confidence x class score of the best class.
            box[4] = box[4] * box[5 + class_index]
            box[5] = class_index
            nms_boxes.append(box)

    return nms_boxes

def detect():
    """Run the detector on one image, draw the detections, then show/save it.

    All settings are read from the module-level ``opt`` namespace:
    ``valid_dir`` (image path), ``path_state_dict`` (weights file),
    ``confidence``, ``iou``, ``possible_pre`` (thresholds forwarded to
    ``NMS``), ``show_img`` and ``save_dir``.

    Raises:
        FileNotFoundError: If the image at ``opt.valid_dir`` cannot be read.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),  # height * width * channel -> channel * height * width
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    image_path = opt.valid_dir
    img_data = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img_data is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    img_data = cv2.resize(img_data, (448, 448), interpolation=cv2.INTER_AREA)
    input_batch = transform(img_data).unsqueeze(0)

    net = YOLO(B=2,classes_num=Classes)
    # The network stays on the CPU here, so map the weights to the CPU too;
    # otherwise a checkpoint saved from a GPU fails to load without CUDA.
    state_dict_load = torch.load(opt.path_state_dict, map_location="cpu")
    net.load_state_dict(state_dict_load)

    net.eval()
    with torch.no_grad():
        bounding_boxes = net(input_batch)

    NMS_boxes = NMS(bounding_boxes,confidence_threshold=opt.confidence,iou_threshold=opt.iou,possible_pred=opt.possible_pre)

    # Everything below is visualisation only and could be factored out into
    # its own helper if this function is ever split up.
    font = ImageFont.truetype(r'font/simsun.ttc', 20, encoding='utf-8')

    if NMS_boxes:
        for box in NMS_boxes:
            cv2.rectangle(img_data, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1)

        # Labels are rendered with PIL so that Chinese class names display
        # correctly. The image is converted once, and labels are drawn after
        # all rectangles so a later rectangle cannot overdraw a label.
        pil_img = Image.fromarray(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        for box in NMS_boxes:
            class_name = ClassesName[int(box[5])]
            score = round(box[4], 2)
            draw.text((box[0], box[1]),"{}:{}".format(class_name, score),(148,175,100),font)
            print("class_name:{} confidence:{}".format(class_name,score))
        img_data = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    if opt.show_img:
        cv2.imshow("img_detection", img_data)
        cv2.waitKey()
        cv2.destroyAllWindows()
    if opt.save_dir:
        cv2.imwrite(opt.save_dir, img_data)




def str2bool(value):
    """Parse a command-line flag value into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value, including ``"False"``, becomes ``True``. This
    parser maps the usual spellings explicitly instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"``; matching ignores case and
            surrounding whitespace. The empty string is treated as false.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    # Command-line entry point: the parsed namespace is stored in the
    # module-level name ``opt``, which detect() reads.
    parser = argparse.ArgumentParser()
    parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\HuLook\Data\DetData\train\images\002.jpg')
    parser.add_argument('--path_state_dict', type=str, default=r'F:\projects\PythonProject\HuLook\runs\traindetect\epx0\weights\best.pth')
    parser.add_argument("--iou",type=float,default=0.2)
    parser.add_argument("--confidence",type=float,default=0.5)
    parser.add_argument("--possible_pre",type=float,default=0.35)
    # str2bool rather than bool, so that "--show_img False" disables the window.
    parser.add_argument("--show_img",type=str2bool,default=True)
    parser.add_argument("--save_dir",type=str,default="")
    opt = parser.parse_args()
    detect()
